{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as dsets\n",
    "from torch.autograd import Variable\n",
    "from torch.nn import Parameter\n",
    "from torch import Tensor\n",
    "import torch.nn.functional as F\n",
    "from util.loader import KddData\n",
    "from origin.lstm import LSTMCell,LSTMModel\n",
    "\n",
    "import math\n",
    "\n",
    "cuda = True if torch.cuda.is_available() else False\n",
    "input_dir = 'data/'\n",
    "input_dim = 10\n",
    "seq_dim = 10\n",
    "num_class = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:1805: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n",
      "  warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n",
      "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:1794: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n",
      "  warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: tensor(0.6819)\n",
      " Recall: 0.5836828351020813, precision: 0.7927520275115967, f1score: 0.5453807711601257\n"
     ]
    }
   ],
   "source": [
    "import torchmetrics\n",
    "\n",
    "def runmodel(dataset, model):\n",
    "    \"\"\"Evaluate ``model`` on the test loader of ``dataset`` and print metrics.\n",
    "\n",
    "    Every batch is reshaped to ``(-1, seq_dim, input_dim)``, moved to the GPU\n",
    "    when one is available, and passed through ``model``. The predicted class\n",
    "    is the index of the largest output along dimension 1. Accuracy and\n",
    "    macro-averaged recall, precision and F1 (over ``num_class`` classes) are\n",
    "    accumulated over all batches and printed once at the end.\n",
    "\n",
    "    Args:\n",
    "        dataset: Object exposing ``test_dataloader``, an iterable of\n",
    "            ``(inputs, labels)`` batches.\n",
    "        model: Callable mapping a reshaped input batch to per-class scores.\n",
    "            NOTE(review): the model runs in whatever train/eval mode the\n",
    "            caller left it in; call ``model.eval()`` first if it contains\n",
    "            dropout or batch-norm layers.\n",
    "\n",
    "    Returns:\n",
    "        None. The results are written to stdout.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the test loader yields no batches.\n",
    "    \"\"\"\n",
    "    accuracy = torchmetrics.Accuracy()\n",
    "    recall = torchmetrics.Recall(average='macro', num_classes=num_class)\n",
    "    precision = torchmetrics.Precision(average='macro', num_classes=num_class)\n",
    "    f1score = torchmetrics.F1(average='macro', num_classes=num_class)\n",
    "    metrics = (accuracy, recall, precision, f1score)\n",
    "\n",
    "    use_cuda = torch.cuda.is_available()\n",
    "    num_batches = 0\n",
    "\n",
    "    # Inference only, so skip autograd bookkeeping.\n",
    "    with torch.no_grad():\n",
    "        for images, labels in dataset.test_dataloader:\n",
    "            images = images.view(-1, seq_dim, input_dim)\n",
    "            if use_cuda:\n",
    "                # NOTE(review): assumes the model's parameters are on the GPU too.\n",
    "                images = images.cuda()\n",
    "\n",
    "            # Forward pass only to get logits/output\n",
    "            outputs = model(images)\n",
    "\n",
    "            # Get predictions from the maximum value\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "\n",
    "            # The metric objects were created on the CPU, so feed them CPU tensors.\n",
    "            labels = labels.cpu()\n",
    "            predicted = predicted.cpu()\n",
    "            for metric in metrics:\n",
    "                # update() only accumulates state; the per-batch value is not needed.\n",
    "                metric.update(predicted, labels)\n",
    "            num_batches += 1\n",
    "\n",
    "    if num_batches == 0:\n",
    "        raise ValueError('test dataloader yielded no batches; nothing to evaluate')\n",
    "\n",
    "    # Aggregate once, after every batch has been seen.\n",
    "    acc = accuracy.compute().detach()\n",
    "    print('Accuracy:', acc)\n",
    "\n",
    "    rec = recall.compute().detach()\n",
    "    prec = precision.compute().detach()\n",
    "    f1sc = f1score.compute().detach()\n",
    "    print(f\" Recall: {rec}, precision: {prec}, f1score: {f1sc}\")\n",
    "\n",
    "\n",
    "# Available test sets; the one passed to KddData below is the one evaluated.\n",
    "testFile  = input_dir + 'test-wired-bin.csv'\n",
    "testFile1 = input_dir + 'test-kr-bin.csv'\n",
    "testFile2 = input_dir + 'test-us-bin.csv'\n",
    "testFile3 = input_dir + 'test-wifi-bin.csv'\n",
    "testFile4 = input_dir + 'test-4G-bin.csv'\n",
    "testFile5 = input_dir + 'test-3G-bin.csv'\n",
    "\n",
    "batch_size = 64\n",
    "\n",
    "# runmodel moves every batch to the GPU when CUDA is available, so load the\n",
    "# checkpoint onto that same device regardless of where it was saved.\n",
    "device = 'cuda' if cuda else 'cpu'\n",
    "\n",
    "# NOTE: torch.load unpickles the file and can execute arbitrary code;\n",
    "# only load checkpoints from a trusted source.\n",
    "model = torch.load('origin/lstm.pth', map_location=device)\n",
    "dataset = KddData(batch_size, testFile4)\n",
    "runmodel(dataset, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
